import torch
import copy
import torch.nn as nn
from tqdm import tqdm
import math
import evaluate as eval_metric
from transformers.optimization import get_scheduler

def create_scheduler(args, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
    """
    Build the learning-rate scheduler for ``optimizer``.

    The scheduler is created through ``get_scheduler`` with no warmup steps
    and no scheduler-specific keyword arguments.

    Args:
        args: Configuration object; only ``args.lr_scheduler_type`` is read.
        num_training_steps (int): The number of training steps to do.
        optimizer (torch.optim.Optimizer): The optimizer whose learning rate
            is scheduled. It is forwarded unchanged, so leaving the default
            ``None`` passes ``None`` on to ``get_scheduler``.

    Returns:
        The scheduler object returned by ``get_scheduler``.
    """
    # The original spelled this as `optimizer if optimizer is None else
    # optimizer`, which evaluates to `optimizer` on both branches.
    return get_scheduler(
        args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_training_steps,
        scheduler_specific_kwargs={},
    )


class Model():
    """Distributed trainer for a model that carries LoRA-style low-rank factors.

    Trainable parameters are split by name into three lists: names containing
    ``lora_A``, names containing ``lora_B``, and everything else.  During
    training the A/B factors may be pruned to a lower rank, and all trainable
    parameters are periodically combined across processes through ``dist``.

    Args:
        args: Configuration namespace.  Attributes read by this class:
            ``rank_mat``, ``rank_min``, ``rank_max``, ``gamma``, ``lamb``,
            ``rank``, ``world_size``, ``gpu``, ``batch_size``,
            ``eval_batch_size``, ``lr_B``, ``com_rounds``, ``com_interval``,
            ``num_epochs``, ``max_grad_norm``, ``dataset`` and ``save_path``
            (plus ``lr_scheduler_type`` via ``create_scheduler``).
        model: The ``nn.Module`` to train.  Its forward must accept
            ``(input_ids, attention_mask=..., labels=...)`` and return an
            object with ``.loss`` / ``.logits``.
        dist: Object providing ``reduce``, ``broadcast``, ``all_gather``,
            ``barrier`` and ``ReduceOp`` (presumably ``torch.distributed`` —
            confirm against the caller).
    """

    def __init__(self, args, model, dist):
        super().__init__()
        self.args = args
        self.model = model
        self.dist = dist
        self.current_rank = self.args.rank_mat
        self.create_opt()
        self.past_regularization = self._regularization_snapshot()

    def regularization(self):
        """Return the pruning regularizer over all tracked (B, A) pairs.

        For each pair, the rows of A from index ``floor(rows * gamma)`` onward
        and the columns of B from index ``floor(cols * gamma)`` onward are
        taken, and the product of their norms is accumulated.

        Returns:
            ``0.0`` when no pairs are tracked, otherwise a scalar tensor that
            is differentiable with respect to the tracked parameters.
        """
        gamma = self.args.gamma
        total = 0.0
        for b, a in zip(self.trainable_params_B, self.trainable_params_A):
            a_start = math.floor(a.size(0) * gamma)
            b_start = math.floor(b.size(1) * gamma)
            total = total + a[a_start:, :].norm() * b[:, b_start:].norm()
        return total

    def _regularization_snapshot(self):
        """Return the regularizer value without recording an autograd graph."""
        with torch.no_grad():
            return self.regularization()

    def create_opt(self):
        """Collect trainable parameters into the A/B/C lists and report counts.

        Despite the name, no optimizer is created here (that happens in
        ``train``).  The per-process trainable-parameter count is sum-reduced
        onto rank 0 and divided by ``world_size`` for the printed summary.
        """
        self.trainable_params_A = []
        self.trainable_params_B = []
        self.trainable_params_C = []
        num_trainable_params = 0
        all_param = 0
        for name, param in self.model.named_parameters():
            num_params = param.numel()
            # if using DS Zero 3 and the weights are initialized empty
            if num_params == 0 and hasattr(param, "ds_numel"):
                num_params = param.ds_numel
            # Due to the design of 4bit linear layers from bitsandbytes
            # one needs to multiply the number of parameters by 2 to get
            # the correct number of parameters
            if param.__class__.__name__ == "Params4bit":
                num_params = num_params * 2

            all_param += num_params

            if not param.requires_grad:
                continue
            if 'lora_A' in name:
                self.trainable_params_A.append(param)
            elif 'lora_B' in name:
                self.trainable_params_B.append(param)
            else:
                self.trainable_params_C.append(param)
            num_trainable_params += num_params

        num_trainable_params = torch.tensor([num_trainable_params]).cuda()
        print(f'{self.args.rank}: {num_trainable_params}')
        # Only rank 0 (the reduce destination) ends up with the full sum.
        self.dist.reduce(num_trainable_params, dst=0, op=self.dist.ReduceOp.SUM)
        print('averaging...')
        num_trainable_params_avg = int(num_trainable_params[0] / self.args.world_size)

        if self.args.rank % self.args.world_size == 0:
            print(f"{self.args.rank}: trainable params: {num_trainable_params_avg} || all params: {all_param:,d} || trainable%: {100 * num_trainable_params_avg / all_param}")

    def start_local_steps(self):
        """Record the current regularizer value as the pruning baseline."""
        self.past_regularization = self._regularization_snapshot()

    def _replace_weight(self, module, tracked_params, new_data):
        """Install ``new_data`` as ``module.weight`` and re-point all references.

        The tracking list and, if it already exists, ``self.optimizer`` are
        updated so they reference the new ``nn.Parameter`` instead of the one
        it replaces.  Optimizer state of the old parameter is discarded
        because its shape no longer matches.
        """
        old_param = module.weight
        new_param = nn.Parameter(new_data, requires_grad=old_param.requires_grad)
        module.weight = new_param

        # Match by identity: `==` / `in` on tensors compares element-wise.
        for idx, tracked in enumerate(tracked_params):
            if tracked is old_param:
                tracked_params[idx] = new_param
                break

        optimizer = getattr(self, "optimizer", None)
        if optimizer is not None:
            for group in optimizer.param_groups:
                group["params"] = [
                    new_param if q is old_param else q for q in group["params"]
                ]
            optimizer.state.pop(old_param, None)
        return new_param

    def end_local_step(self):
        """Prune the LoRA factors if the regularizer decreased since the baseline.

        When the regularizer is lower than ``past_regularization`` and the
        current rank is above ``args.rank_min``, every ``nn.Linear`` whose
        module name contains ``lora_A`` keeps its first ``floor(rows * gamma)``
        rows and every one containing ``lora_B`` keeps its first
        ``floor(cols * gamma)`` columns.
        """
        x = self._regularization_snapshot()
        print(f'past_regularization: {self.past_regularization}')
        print(f'cur_regularization: {x}')
        # prune the lora_A and lora_B if the regularization becomes smaller
        if self.past_regularization - x > 0 and self.current_rank > self.args.rank_min:
            print('pruning')

            for name, module in self.model.named_modules():
                if not isinstance(module, nn.Linear):
                    continue
                if 'lora_A' in name:
                    self.current_rank = math.floor(module.weight.size(0) * self.args.gamma)
                    # clone() so the pruned weight owns its storage instead of
                    # keeping the full-rank tensor alive through a view.
                    kept = module.weight.data[:self.current_rank, :].clone()
                    self._replace_weight(module, self.trainable_params_A, kept)
                if 'lora_B' in name:
                    self.current_rank = math.floor(module.weight.size(1) * self.args.gamma)
                    kept = module.weight.data[:, :self.current_rank].clone()
                    new_param = self._replace_weight(module, self.trainable_params_B, kept)
                    self.current_rank = new_param.size(1)
        print(f'rank: {self.args.rank}, r: {self.current_rank}')

    @torch.no_grad()
    def _aggregate_parameters(self):
        """Combine every trainable parameter across processes.

        LoRA factors are zero-padded up to ``args.rank_max`` along the rank
        dimension before being reduced, so processes holding different ranks
        can exchange same-shaped tensors; each process then keeps only the
        slice matching its own current shape.  Other trainable parameters are
        averaged directly.  Each tensor is sum-reduced onto rank 0 and
        broadcast back.
        """
        world_size = self.args.world_size
        rank_max = self.args.rank_max

        norm_BA = [
            (b.detach() @ a.detach()).norm()
            for b, a in zip(self.trainable_params_B, self.trainable_params_A)
        ]
        # NOTE(review): copy.copy is a shallow copy, so personal_norm_BA holds
        # the very same tensor objects that the in-place reduce/broadcast
        # below overwrite.  The ratio personal/total used further down is
        # therefore 1 (NaN if the norm is 0) and the aggregation is a plain
        # average.  Cloning the tensors would enable the norm weighting, but
        # combined with the existing division by world_size it would scale
        # the parameters down every round — confirm the intended weighting
        # before changing this.
        personal_norm_BA = copy.copy(norm_BA)
        # compute the sum of all the clients' BA norm
        for norm in norm_BA:
            self.dist.reduce(norm, dst=0, op=self.dist.ReduceOp.SUM)
            self.dist.broadcast(norm, src=0)

        i, j = 0, 0
        for n, p in self.model.named_parameters():
            if not p.requires_grad:
                continue
            is_lora_a = 'lora_A' in n
            is_lora_b = (not is_lora_a) and 'lora_B' in n

            if is_lora_a:
                zero_tensor = torch.zeros(rank_max - p.size(0), p.size(1), dtype=p.dtype, device=p.device)
                send_p = torch.cat([p, zero_tensor], dim=0) / world_size
                send_p *= personal_norm_BA[i] / norm_BA[i]
                i += 1
            elif is_lora_b:
                zero_tensor = torch.zeros(p.size(0), rank_max - p.size(1), dtype=p.dtype, device=p.device)
                send_p = torch.cat([p, zero_tensor], dim=1) / world_size
                send_p *= personal_norm_BA[j] / norm_BA[j]
                j += 1
            else:
                send_p = p / world_size

            self.dist.reduce(send_p, dst=0, op=self.dist.ReduceOp.SUM)
            self.dist.broadcast(send_p, src=0)

            # Write back in place, keeping only this process's current shape.
            if is_lora_a:
                p.data.copy_(send_p[:p.size(0), :])
            elif is_lora_b:
                p.data.copy_(send_p[:, :p.size(1)])
            else:
                p.data.copy_(send_p)

    def train(self, train_dataset, eval_dataset):
        """Run the training loop with periodic pruning, aggregation and evaluation.

        Every ``args.com_interval`` batches (counted by batch index within the
        epoch, including index 0) the factors are possibly pruned, parameters
        are aggregated across processes and a communication round is counted.
        Training stops once ``args.com_rounds`` rounds are done or
        ``args.num_epochs`` epochs have run.  After each epoch the model is
        evaluated, results are gathered from all processes, and rank 0 writes
        them to ``args.save_path + '.pkl'`` and ``args.save_path + '.txt'``.

        Args:
            train_dataset: Dataset whose batches provide ``"input_ids"``,
                ``"attention_mask"`` and ``"label"`` tensors.
            eval_dataset: Dataset with the same batch layout, used for
                evaluation.
        """
        device = torch.device(f"cuda:{self.args.gpu}")
        print(f'---------------{device}----------------')
        self.model = self.model.cuda()
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.args.batch_size, shuffle=True)
        eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=self.args.eval_batch_size)

        # NOTE(review): all three groups (including the lora_A group) use
        # args.lr_B — confirm whether a separate rate for A was intended.
        optimizer_grouped_parameters = [
            {"params": group_params, "weight_decay": 1e-4, "lr": self.args.lr_B}
            for group_params in (
                self.trainable_params_A,
                self.trainable_params_B,
                self.trainable_params_C,
            )
        ]
        self.optimizer = torch.optim.AdamW(optimizer_grouped_parameters)
        num_training_steps = self.args.com_rounds * self.args.com_interval
        lr_scheduler = create_scheduler(self.args, num_training_steps, self.optimizer)

        is_mnli = self.args.dataset == "mnli"
        is_reporting_rank = self.args.rank % self.args.world_size == 0
        test_results_list = {"mnli-m/acc": [], "mnli-mm/acc": []} if is_mnli else {}
        metric = eval_metric.load("glue", self.args.dataset)

        loss_list = []
        cut_round = 0
        for epoch in range(self.args.num_epochs):
            # Re-enter train mode every epoch so that dropout etc. are active
            # again after the evaluation pass of the previous epoch.
            self.model.train()
            for batch_idx, batch in enumerate(train_loader):
                if cut_round >= self.args.com_rounds:
                    break
                input_ids = batch["input_ids"].cuda()
                attention_mask = batch["attention_mask"].cuda()
                labels = batch["label"].cuda()
                self.optimizer.zero_grad()
                loss = self.learn(input_ids, attention_mask, labels)
                loss_value = loss.item()
                loss_list.append(round(loss_value, 4))

                # Each group is clipped against its own gradient norm.
                torch.nn.utils.clip_grad_norm_(self.trainable_params_A, self.args.max_grad_norm)
                torch.nn.utils.clip_grad_norm_(self.trainable_params_B, self.args.max_grad_norm)
                torch.nn.utils.clip_grad_norm_(self.trainable_params_C, self.args.max_grad_norm)
                self.optimizer.step()
                lr_scheduler.step()
                # get_last_lr() is the supported way to read the current rate.
                print(f"Rank: {self.args.rank}, Train loss: {loss_value}, r: {self.current_rank}, lr: {lr_scheduler.get_last_lr()}")

                if batch_idx % self.args.com_interval == 0:
                    self.end_local_step()
                    self._aggregate_parameters()
                    cut_round += 1
                    self.start_local_steps()

            # Add a synchronization barrier before dist.all_gather
            print('Synchronization barrier')
            self.dist.barrier()
            eval_results = self.evaluate(eval_loader, metric)
            print(f'rank: {self.args.rank}, r:{self.current_rank}, {eval_results}')
            self.dist.barrier()

            # Convert the evaluation results to a tensor
            if is_mnli:
                # NOTE(review): both slots carry the same "accuracy" value, so
                # the reported matched and mismatched numbers are identical.
                eval_values = [eval_results["accuracy"], eval_results["accuracy"]]
            else:
                eval_values = list(eval_results.values())
            eval_results_tensor = torch.tensor(eval_values, dtype=torch.float32).cuda()

            # Gather evaluation results from all clients
            eval_results_list = [torch.zeros_like(eval_results_tensor) for _ in range(self.args.world_size)]
            self.dist.all_gather(eval_results_list, eval_results_tensor)

            self.dist.barrier()

            if is_reporting_rank:
                mean_results = torch.stack(eval_results_list).mean(dim=0)
                if is_mnli:
                    test_results_list["mnli-m/acc"].append(mean_results[0].item())
                    test_results_list["mnli-mm/acc"].append(mean_results[1].item())
                else:
                    for idx, key in enumerate(eval_results.keys()):
                        avg_value = round(mean_results[idx].item(), 4)
                        test_results_list.setdefault(key, []).append(avg_value)

                print(f"Epoch {epoch + 1}/{self.args.num_epochs}:")
                print(f"Average evaluation results: {test_results_list}")
                torch.save((test_results_list, loss_list), self.args.save_path + '.pkl')
                with open(self.args.save_path + '.txt', 'w') as f:
                    f.write(str({'Exp config': str(self.args), 'Average evaluation results': str(test_results_list)}))

            if cut_round >= self.args.com_rounds:
                if is_reporting_rank:
                    print(f"Average evaluation results: {test_results_list}")
                break

    def evaluate(self, dataloader, metric):
        """Aggregate parameters across processes, then score the model.

        The model is switched to eval mode for the pass, gradients are not
        recorded, and the previous train/eval mode is restored afterwards.

        Args:
            dataloader: Iterable of batches providing ``"input_ids"``,
                ``"attention_mask"`` and ``"label"`` tensors.
            metric: Object with ``add_batch(predictions=..., references=...)``
                and ``compute()``.

        Returns:
            Whatever ``metric.compute()`` returns.
        """
        self._aggregate_parameters()

        was_training = self.model.training
        self.model.eval()
        progress_bar = tqdm(dataloader, desc="Evaluation", unit="batch")
        try:
            with torch.no_grad():
                for batch in progress_bar:
                    input_ids = batch["input_ids"].cuda()
                    attention_mask = batch["attention_mask"].cuda()
                    labels = batch["label"].cuda()
                    outputs = self.model(input_ids, attention_mask=attention_mask)
                    predictions = torch.argmax(outputs.logits, dim=-1)
                    metric.add_batch(predictions=predictions, references=labels)
            results = metric.compute()
        finally:
            # Without this the caller would keep training in eval mode.
            self.model.train(was_training)
        return results

    def learn(self, input_ids, attention_mask, labels):
        """Run one forward/backward pass and return the regularized loss.

        The loss is the model loss plus ``args.lamb`` times the pruning
        regularizer.  Gradients are accumulated by ``backward()``; the
        optimizer step is left to the caller.
        """
        outputs = self.model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss + self.args.lamb * self.regularization()
        loss.backward()
        return loss